{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 优化器"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "\n",
    "# 设置权重，服从正态分布  --> 2 x 2\n",
    "weight = torch.randn((2, 2), requires_grad=True)\n",
    "# 设置梯度为全1矩阵  --> 2 x 2\n",
    "weight.grad = torch.ones((2, 2))\n",
    "# 输出现有的weight和data\n",
    "print(\"The data of weight before step:\\n{}\".format(weight.data))\n",
    "print(\"The grad of weight before step:\\n{}\".format(weight.grad))\n",
    "# 实例化优化器\n",
    "optimizer = torch.optim.SGD([weight], lr=0.1, momentum=0.9)\n",
    "# 进行一步操作\n",
    "optimizer.step()\n",
    "# 查看进行一步后的值，梯度\n",
    "print(\"The data of weight after step:\\n{}\".format(weight.data))\n",
    "print(\"The grad of weight after step:\\n{}\".format(weight.grad))\n",
    "# 权重清零\n",
    "optimizer.zero_grad()\n",
    "# 检验权重是否为0\n",
    "print(\"The grad of weight after optimizer.zero_grad():\\n{}\".format(weight.grad))\n",
    "# 输出参数\n",
    "print(\"optimizer.params_group is \\n{}\".format(optimizer.param_groups))\n",
    "# 查看参数位置，optimizer和weight的位置一样，我觉得这里可以参考Python是基于值管理\n",
    "print(\"weight in optimizer:{}\\nweight in weight:{}\\n\".format(id(optimizer.param_groups[0]['params'][0]), id(weight)))\n",
    "# 添加参数：weight2\n",
    "weight2 = torch.randn((3, 3), requires_grad=True)\n",
    "optimizer.add_param_group({\"params\": weight2, 'lr': 0.0001, 'nesterov': True})\n",
    "# 查看现有的参数\n",
    "print(\"optimizer.param_groups is\\n{}\".format(optimizer.param_groups))\n",
    "# 查看当前状态信息\n",
    "opt_state_dict = optimizer.state_dict()\n",
    "print(\"state_dict before step:\\n\", opt_state_dict)\n",
    "# 进行5次step操作\n",
    "for _ in range(50):\n",
    "    optimizer.step()\n",
    "# 输出现有状态信息\n",
    "print(\"state_dict after step:\\n\", optimizer.state_dict())\n",
    "# 保存参数信息\n",
    "torch.save(optimizer.state_dict(),os.path.join(r\"D:\\pythonProject\\Attention_Unet\", \"optimizer_state_dict.pkl\"))\n",
    "print(\"----------done-----------\")\n",
    "# 加载参数信息\n",
    "state_dict = torch.load(r\"D:\\pythonProject\\Attention_Unet\\optimizer_state_dict.pkl\") # 需要修改为你自己的路径\n",
    "optimizer.load_state_dict(state_dict)\n",
    "print(\"load state_dict successfully\\n{}\".format(state_dict))\n",
    "# 输出最后属性信息\n",
    "print(\"\\n{}\".format(optimizer.defaults))\n",
    "print(\"\\n{}\".format(optimizer.state))\n",
    "print(\"\\n{}\".format(optimizer.param_groups))"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
